{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implements a Siamese/Y-Network using Functional API\n",
    "\n",
    "This is our first example of a network with a more complex graph. We call it Y-Network because it has a shape that is similar to the letter Y. There are two branches, left and right. Each one gets the same copy of input. Each branch processes the input and produces a different set of features. The left and right feature maps are the combined and passed to a head `Dense` layer for logistic regression. \n",
    "\n",
    "We use the same optimizer (`sgd`) and loss function (`categorical_crossentropy`). We train the network for 20 epochs.  \n",
    "\n",
    "~98.6% test accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from tensorflow.keras.layers import Dense, Dropout, Input\n",
    "from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.layers import concatenate\n",
    "from tensorflow.keras.datasets import mnist\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "\n",
    "# load MNIST dataset\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "# from sparse label to categorical\n",
    "num_labels = len(np.unique(y_train))\n",
    "y_train = to_categorical(y_train)\n",
    "y_test = to_categorical(y_test)\n",
    "\n",
    "# reshape and normalize input images\n",
    "image_size = x_train.shape[1]\n",
    "x_train = np.reshape(x_train,[-1, image_size, image_size, 1])\n",
    "x_test = np.reshape(x_test,[-1, image_size, image_size, 1])\n",
    "x_train = x_train.astype('float32') / 255\n",
    "x_test = x_test.astype('float32') / 255\n",
    "\n",
    "# network parameters\n",
    "input_shape = (image_size, image_size, 1)\n",
    "batch_size = 128\n",
    "kernel_size = 3\n",
    "filters = 64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Left Branch of a Y-Network\n",
    "\n",
    "The left branch is made of 3 layers of CNN with the configuration as the single branch CNN model example. To save in space, the left branch is constructed using a `for` loop. This technique and is used in constructing bigger models such as ResNet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# left branch of Y network\n",
    "left_inputs = Input(shape=input_shape)\n",
    "x = left_inputs\n",
    "# 3 layers of Conv2D-MaxPooling2D\n",
    "depth = 3\n",
    "for i in range(depth):\n",
    "    x = Conv2D(filters=filters,\n",
    "               kernel_size=kernel_size,\n",
    "               padding='same',\n",
    "               activation='relu')(x)\n",
    "    #x = Dropout(dropout)(x)\n",
    "    if i < (depth - 1):\n",
    "        x = MaxPooling2D()(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Right Branch of a Y-Network\n",
    "\n",
    "The right branch is an exact mirror of the left branch. To ensure that it learns a different set of features, we use `dilation_rate = 2` to approximate a kernel with twice the size as the left brancg."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# right branch of Y network\n",
    "right_inputs = Input(shape=input_shape)\n",
    "y = right_inputs\n",
    "# 3 layers of Conv2D-Dropout-MaxPooling2D\n",
    "for i in range(depth):\n",
    "    y = Conv2D(filters=filters,\n",
    "               kernel_size=kernel_size,\n",
    "               padding='same',\n",
    "               activation='relu',\n",
    "               dilation_rate=2)(y)\n",
    "    #y = Dropout(dropout)(y)\n",
    "    if i < (depth - 1):\n",
    "        y = MaxPooling2D()(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Merging the 2 Branches\n",
    "\n",
    "To complete a Y-Network, let us merge the outputs of left and right branches. We use `concatenate()` which results to feature maps with the same dimension as left or right branch feature maps but with twice the number. There are other merging functions in Keras such as `add` and `multiply`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"Y_Network\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_17 (InputLayer)           [(None, 28, 28, 1)]  0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_18 (InputLayer)           [(None, 28, 28, 1)]  0                                            \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_48 (Conv2D)              (None, 28, 28, 64)   640         input_17[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_51 (Conv2D)              (None, 28, 28, 64)   640         input_18[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling2d_44 (MaxPooling2D) (None, 14, 14, 64)   0           conv2d_48[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling2d_46 (MaxPooling2D) (None, 14, 14, 64)   0           conv2d_51[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_49 (Conv2D)              (None, 14, 14, 64)   36928       max_pooling2d_44[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_52 (Conv2D)              (None, 14, 14, 64)   36928       max_pooling2d_46[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling2d_45 (MaxPooling2D) (None, 7, 7, 64)     0           conv2d_49[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling2d_47 (MaxPooling2D) (None, 7, 7, 64)     0           conv2d_52[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_50 (Conv2D)              (None, 7, 7, 64)     36928       max_pooling2d_45[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_53 (Conv2D)              (None, 7, 7, 64)     36928       max_pooling2d_47[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_11 (Concatenate)    (None, 7, 7, 128)    0           conv2d_50[0][0]                  \n",
      "                                                                 conv2d_53[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "flatten_5 (Flatten)             (None, 6272)         0           concatenate_11[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "dense_4 (Dense)                 (None, 10)           62730       flatten_5[0][0]                  \n",
      "==================================================================================================\n",
      "Total params: 211,722\n",
      "Trainable params: 211,722\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# merge left and right branches outputs\n",
    "y = concatenate([x, y])\n",
    "# feature maps to vector in preparation to connecting to Dense layer\n",
    "y = Flatten()(y)\n",
    "# y = Dropout(dropout)(y)\n",
    "outputs = Dense(num_labels, activation='softmax')(y)\n",
    "\n",
    "# build the model in functional API\n",
    "model = Model([left_inputs, right_inputs], outputs, name='Y_Network')\n",
    "# verify the model using graph\n",
    "# plot_model(model, to_file='cnn-y-network.png', show_shapes=True)\n",
    "# verify the model using layer text description\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Training and Validation\n",
    "\n",
    "This is just our usual model training and validation. Similar to our previous examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "469/469 [==============================] - 108s 231ms/step - loss: 1.5014 - accuracy: 0.6331 - val_loss: 0.4413 - val_accuracy: 0.8707\n",
      "Epoch 2/20\n",
      "469/469 [==============================] - 111s 238ms/step - loss: 0.3685 - accuracy: 0.8891 - val_loss: 0.3066 - val_accuracy: 0.9075\n",
      "Epoch 3/20\n",
      "469/469 [==============================] - 112s 239ms/step - loss: 0.2683 - accuracy: 0.9195 - val_loss: 0.2423 - val_accuracy: 0.9280\n",
      "Epoch 4/20\n",
      "469/469 [==============================] - 111s 237ms/step - loss: 0.2097 - accuracy: 0.9370 - val_loss: 0.1685 - val_accuracy: 0.9501\n",
      "Epoch 5/20\n",
      "469/469 [==============================] - 109s 233ms/step - loss: 0.1693 - accuracy: 0.9499 - val_loss: 0.1477 - val_accuracy: 0.9544\n",
      "Epoch 6/20\n",
      "469/469 [==============================] - 103s 219ms/step - loss: 0.1415 - accuracy: 0.9576 - val_loss: 0.1199 - val_accuracy: 0.9630\n",
      "Epoch 7/20\n",
      "469/469 [==============================] - 104s 222ms/step - loss: 0.1211 - accuracy: 0.9639 - val_loss: 0.1039 - val_accuracy: 0.9695\n",
      "Epoch 8/20\n",
      "469/469 [==============================] - 104s 221ms/step - loss: 0.1073 - accuracy: 0.9681 - val_loss: 0.0863 - val_accuracy: 0.9742\n",
      "Epoch 9/20\n",
      "469/469 [==============================] - 103s 220ms/step - loss: 0.0965 - accuracy: 0.9710 - val_loss: 0.0829 - val_accuracy: 0.9739\n",
      "Epoch 10/20\n",
      "469/469 [==============================] - 110s 235ms/step - loss: 0.0875 - accuracy: 0.9740 - val_loss: 0.0793 - val_accuracy: 0.9733\n",
      "Epoch 11/20\n",
      "469/469 [==============================] - 110s 235ms/step - loss: 0.0819 - accuracy: 0.9754 - val_loss: 0.0785 - val_accuracy: 0.9753\n",
      "Epoch 12/20\n",
      "469/469 [==============================] - 110s 235ms/step - loss: 0.0755 - accuracy: 0.9768 - val_loss: 0.0639 - val_accuracy: 0.9795\n",
      "Epoch 13/20\n",
      "469/469 [==============================] - 113s 241ms/step - loss: 0.0713 - accuracy: 0.9786 - val_loss: 0.0607 - val_accuracy: 0.9816\n",
      "Epoch 14/20\n",
      "469/469 [==============================] - 111s 237ms/step - loss: 0.0676 - accuracy: 0.9800 - val_loss: 0.0658 - val_accuracy: 0.9800\n",
      "Epoch 15/20\n",
      "469/469 [==============================] - 109s 231ms/step - loss: 0.0642 - accuracy: 0.9807 - val_loss: 0.0539 - val_accuracy: 0.9825\n",
      "Epoch 16/20\n",
      "469/469 [==============================] - 110s 234ms/step - loss: 0.0613 - accuracy: 0.9816 - val_loss: 0.0562 - val_accuracy: 0.9830\n",
      "Epoch 17/20\n",
      "469/469 [==============================] - 110s 235ms/step - loss: 0.0587 - accuracy: 0.9822 - val_loss: 0.0541 - val_accuracy: 0.9832\n",
      "Epoch 18/20\n",
      "469/469 [==============================] - 121s 258ms/step - loss: 0.0567 - accuracy: 0.9829 - val_loss: 0.0501 - val_accuracy: 0.9836\n",
      "Epoch 19/20\n",
      "469/469 [==============================] - 123s 262ms/step - loss: 0.0543 - accuracy: 0.9837 - val_loss: 0.0468 - val_accuracy: 0.9846\n",
      "Epoch 20/20\n",
      "469/469 [==============================] - 107s 227ms/step - loss: 0.0520 - accuracy: 0.9841 - val_loss: 0.0448 - val_accuracy: 0.9863\n",
      "79/79 [==============================] - 6s 82ms/step - loss: 0.0448 - accuracy: 0.9863\n",
      "\n",
      "Test accuracy: 98.6%\n"
     ]
    }
   ],
   "source": [
    "# classifier loss, Adam optimizer, classifier accuracy\n",
    "model.compile(loss='categorical_crossentropy',\n",
    "              optimizer='sgd',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "# train the model with input images and labels\n",
    "model.fit([x_train, x_train],\n",
    "          y_train, \n",
    "          validation_data=([x_test, x_test], y_test),\n",
    "          epochs=20,\n",
    "          batch_size=batch_size)\n",
    "\n",
    "# model accuracy on test dataset\n",
    "score = model.evaluate([x_test, x_test], y_test, batch_size=batch_size)\n",
    "print(\"\\nTest accuracy: %.1f%%\" % (100.0 * score[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
